from lmse import *
from common import *
import numpy as np


# Number of classes / one-vs-rest models; labels are assumed to be 0.._NUM_CLASSES-1.
_NUM_CLASSES = 10

# Set to True to also print accuracy on the training set (classifies every
# training sample against every model, so it roughly triples the run time of
# the evaluation phase).
_REPORT_TRAIN_STATS = False


def _negated(sample):
    """Return a new list holding the element-wise negation of *sample*."""
    return [-value for value in sample]


def _build_training_set(data, label, negated_by_label):
    """Return the sign-normalised training matrix for the one-vs-rest model *label*.

    Rows are the samples of class *label* unchanged, followed by the
    pre-negated samples of every other class (in class order). Every row is
    later trained against a target of 1. Assumption, to confirm against
    lmse.py: LMSE.train fits w with samples @ w ~= targets. Under that
    assumption this is equivalent to a target of +1 for the class and -1
    for all other classes.
    """
    rows = list(data[label])
    for other_label, negated_samples in enumerate(negated_by_label):
        if other_label != label:
            rows.extend(negated_samples)
    return np.array(rows)


def _evaluate(classifiers, data, title):
    """Print per-label and overall accuracy of *classifiers* on *data*.

    Each sample is assigned to the label whose classifier returns the
    highest ``classify`` score (ties resolve to the lowest label, as with
    ``np.argmax``). A label with no samples reports an accuracy of nan
    instead of raising ZeroDivisionError.

    Args:
        classifiers: one trained model per label, indexed by label.
        data: samples grouped by label, ``data[label]`` being a sequence.
        title: heading printed (after a blank line) before the report.
    """
    print()
    print(title)
    num_correct_all_labels = 0
    num_total_all_labels = 0
    for label in range(len(classifiers)):
        samples = data[label]
        num_correct = sum(
            1
            for sample in samples
            if np.argmax([clf.classify(sample) for clf in classifiers]) == label
        )
        num_total = len(samples)
        accuracy = num_correct / num_total if num_total else float('nan')
        print(f'Label {label}, num_correct = {num_correct}, '
              f'num_total = {num_total}, accuracy = {accuracy}')
        num_correct_all_labels += num_correct
        num_total_all_labels += num_total
    accuracy = (num_correct_all_labels / num_total_all_labels
                if num_total_all_labels else float('nan'))
    print(f'All labels, num_correct = {num_correct_all_labels}, '
          f'num_total = {num_total_all_labels}, accuracy = {accuracy}')


if __name__ == '__main__':
    data = read_data_by_labels('TrainSamples.csv', 'TrainLabels.csv')

    # Negate each class's samples once; all one-vs-rest models reuse these
    # copies instead of re-negating the whole dataset per model.
    negated_by_label = [[_negated(sample) for sample in data[i]]
                        for i in range(_NUM_CLASSES)]

    lmses = [LMSE() for _ in range(_NUM_CLASSES)]
    for label, lmse in enumerate(lmses):
        samples = _build_training_set(data, label, negated_by_label)
        print('Training LMSE ' + str(label))
        lmse.train(samples, np.ones(len(samples), dtype=int))

    if _REPORT_TRAIN_STATS:
        _evaluate(lmses, data, 'Train statistics')

    # Reference train results from an earlier run:
    # Label 0, num_correct = 2728, num_total = 2975, accuracy = 0.9169747899159664
    # Label 1, num_correct = 3293, num_total = 3419, accuracy = 0.9631471190406552
    # Label 2, num_correct = 2204, num_total = 2984, accuracy = 0.7386058981233244
    # Label 3, num_correct = 2375, num_total = 3030, accuracy = 0.7838283828382838
    # Label 4, num_correct = 2537, num_total = 2964, accuracy = 0.8559379217273954
    # Label 5, num_correct = 1105, num_total = 2710, accuracy = 0.4077490774907749
    # Label 6, num_correct = 2611, num_total = 2950, accuracy = 0.8850847457627119
    # Label 7, num_correct = 2784, num_total = 3164, accuracy = 0.8798988621997471
    # Label 8, num_correct = 1938, num_total = 2900, accuracy = 0.6682758620689655
    # Label 9, num_correct = 1223, num_total = 2904, accuracy = 0.42114325068870523
    # All labels, num_correct = 22798, num_total = 30000, accuracy = 0.7599333333333333

    test_data = read_data_by_labels('TestSamples.csv', 'TestLabels.csv')
    _evaluate(lmses, test_data, 'Test statistics')

    # Reference test results from an earlier run:
    # Label 0, num_correct = 927, num_total = 1009, accuracy = 0.9187314172447968
    # Label 1, num_correct = 1110, num_total = 1150, accuracy = 0.9652173913043478
    # Label 2, num_correct = 689, num_total = 963, accuracy = 0.715472481827622
    # Label 3, num_correct = 800, num_total = 1021, accuracy = 0.7835455435847208
    # Label 4, num_correct = 806, num_total = 960, accuracy = 0.8395833333333333
    # Label 5, num_correct = 388, num_total = 917, accuracy = 0.42311886586695746
    # Label 6, num_correct = 866, num_total = 974, accuracy = 0.8891170431211499
    # Label 7, num_correct = 912, num_total = 1047, accuracy = 0.8710601719197708
    # Label 8, num_correct = 663, num_total = 965, accuracy = 0.6870466321243524
    # Label 9, num_correct = 391, num_total = 994, accuracy = 0.3933601609657948
    # All labels, num_correct = 7552, num_total = 10000, accuracy = 0.7552
